Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


by Misa Ogura

Follow along @ https://bit.ly/2WzVzgu

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist, turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Feature visualisation


Introducing FlashTorch 🔦


  • Open source feature visualisation toolkit for neural nets in PyTorch

  • Supports torchvision models

  • Available to install via pip!

      $ pip install flashtorch

Image processing & CNN 101


Kernel & convolution


Kernel: a small matrix used for edge detection, blurring, sharpening, embossing, etc.

Convolution: an operation to calculate a weighted sum of neighbouring pixels

Examples of convolution: detecting edges
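
The edge-detection example above can be sketched with a naive convolution in NumPy. This is a hypothetical helper for illustration only (note that CNN libraries actually compute cross-correlation, i.e. the kernel is not flipped):

```python
import numpy as np

def convolve2d(image, kernel):
    # Naive 'valid' convolution: slide the kernel over the image and take
    # the weighted sum of neighbouring pixels at each position.
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# A vertical step edge: dark on the left, bright on the right.
image = np.zeros((5, 6))
image[:, 3:] = 1.0

# Sobel kernel that responds to vertical edges.
sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

edges = convolve2d(image, sobel_x)
print(edges)  # non-zero only around the dark-to-bright boundary
```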


Typical CNN architecture


  • Kernel weights are learnt during training

  • Extract features that are relevant to the task at hand

Feature visualisation technique

Saliency maps


Saliency


  • A subjective quality in human visual perception

  • Makes certain items stand out and grabs our attention

Saliency maps in computer vision: indications of the most “salient” regions

Saliency maps in CNNs


  • First introduced in 2013

  • Gradients of target class w.r.t. input image via backpropagation

  • Pixels with positive gradients: some intuition of attention

  • Available via the flashtorch.saliency API
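
The core idea — gradients of the target-class score w.r.t. the input image, obtained via backpropagation — can be sketched in plain PyTorch. The tiny CNN below is a hypothetical stand-in for a real torchvision model, just to show the mechanics:

```python
import torch
import torch.nn as nn

# A tiny stand-in CNN (assumption: not a torchvision model).
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                      nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(),
                      nn.Linear(8, 10))
model.eval()

# Track gradients on the input itself.
input_ = torch.rand(1, 3, 32, 32, requires_grad=True)
target_class = 3

output = model(input_)
# Backpropagate from the target-class score only.
output[0, target_class].backward()

# Saliency: per-pixel gradient magnitude, max over colour channels.
saliency = input_.grad.data.abs().max(dim=1)[0]
print(saliency.shape)
```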

FlashTorch demo 1

Visualising saliency maps with backpropagation


Install FlashTorch & load an image



$ pip install flashtorch

...
In [2]:
import matplotlib.pyplot as plt

from flashtorch.utils import load_image

image = load_image('../../examples/images/great_grey_owl.jpg')

plt.imshow(image)
plt.title('Original image')
plt.axis('off');

Apply transformations


In [3]:
from flashtorch.utils import apply_transforms, denormalize, format_for_plotting

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Create a Backprop object with a pre-trained model


In [4]:
from torchvision import models

from flashtorch.saliency import Backprop

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
  • Registers custom functions to model layers
  • Grabs intermediate gradients out of the computational graph

To calculate gradients:

Signature:

    backprop.calculate_gradients(input_, target_class=None, ...)

Calculate the gradients of target class w.r.t. input


In [5]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(f'Target class index: {target_class}')

# Ready to calculate gradients!

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(gradients), gradients.shape)
print(type(max_gradients), max_gradients.shape)
Target class index: 24
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's visualise gradients


In [6]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)
Pixels where the animal is present have the strongest positive effects.

But it's quite noisy...

FlashTorch demo 2

Visualising saliency maps with guided backpropagation


Guided backpropagation


  • Additional guidance from the higher layers during backprop

  • Masks out neurons that had no effect or negative effects on the prediction

  • Preventing the flow of such gradients: less noise
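
Under the hood, this guidance can be implemented with a backward hook on each ReLU that clamps negative gradients to zero. A minimal sketch with a toy model — an assumption about the mechanics, not FlashTorch's actual implementation:

```python
import torch
import torch.nn as nn

def relu_guided_hook(module, grad_in, grad_out):
    # Guided backprop: only let positive gradients flow back through ReLUs.
    # (ReLU's own backward already masks positions that were inactive
    # in the forward pass.)
    return (torch.clamp(grad_in[0], min=0.0),)

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
                      nn.Flatten(), nn.Linear(4 * 8 * 8, 5))

# Register the guidance hook on every ReLU in the model.
for module in model.modules():
    if isinstance(module, nn.ReLU):
        module.register_full_backward_hook(relu_guided_hook)

input_ = torch.rand(1, 3, 8, 8, requires_grad=True)
model(input_)[0, 2].backward()

guided = input_.grad  # gradients with negative flows masked at each ReLU
```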

In [7]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)

max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

Visualise guided gradients


In [8]:
visualize(input_, guided_gradients, max_guided_gradients)
Now that's much less noisy!

Pixels around the head and eyes have the strongest positive effects.

What about other birds?

What makes a peacock a peacock?


In [10]:
visualize(input_, guided_gradients, max_guided_gradients)

... or a toucan?


In [12]:
visualize(input_, guided_gradients, max_guided_gradients)
Do you agree with the network? 🤖

FlashTorch demo 3

Gaining additional insights on transfer learning


Transfer learning


  • A model developed for a task is reused as a starting point for another task

  • Often used in computer vision & natural language processing tasks

  • Save compute & time resources

Building a flower classifier


From: Densenet model, pre-trained on ImageNet (1000 classes)

To: Flower classifier to recognise 102 species of flowers (dataset)

Pre-trained model - 0.1% test accuracy 😨


In [15]:
backprop = Backprop(pretrained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:94: UserWarning: The predicted class index 70 does not equal the target class index 96. Calculating the gradient w.r.t. the predicted class.
  'the gradient w.r.t. the predicted class.'
The trained model achieved 98.7% test accuracy.

But why is it better now?

Trained model - 98.7% test accuracy


In [16]:
backprop = Backprop(trained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
The trained model has learnt to shift focus onto the distinguishing pattern.

Let's make neural nets more interpretable!


  • Saliency maps as indications of the network's attention

  • flashtorch.saliency module to visualise saliency maps for CNNs in PyTorch

  • Asking why the network behaves in the way it does: a step forward from just looking at accuracy

Thank you!


🌡 Like it? Try out FlashTorch 🔦 on Google Colab

🙏 Comments, questions and feedback on the talk: Pull Request

🤝 General suggestions & contribution: Submit issues